{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:12<00:00, 19.03it/s]\n",
      "  1%|▋                                                                                 | 2/235 [00:00<00:14, 15.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.8011\n",
      "train accuracy 0.741, test accuracy 0.786\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:17<00:00, 13.27it/s]\n",
      "  1%|▋                                                                                 | 2/235 [00:00<00:17, 13.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, loss 0.5776\n",
      "train accuracy 0.809, test accuracy 0.808\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:22<00:00, 10.68it/s]\n",
      "  0%|                                                                                          | 0/235 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, loss 0.5339\n",
      "train accuracy 0.820, test accuracy 0.817\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:21<00:00, 10.87it/s]\n",
      "  1%|▋                                                                                 | 2/235 [00:00<00:19, 12.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, loss 0.5082\n",
      "train accuracy 0.829, test accuracy 0.820\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:15<00:00, 15.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, loss 0.4915\n",
      "train accuracy 0.833, test accuracy 0.826\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import tensor\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# 1、加载Fashion-MNIST数据集（采用已划分好的训练集和测试集）\n",
    "mnist_train=torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=True,download=True,transform=transforms.ToTensor())\n",
    "mnist_test=torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',train=False,download=True,transform=transforms.ToTensor())\n",
    "\n",
    "# 2、通过Dataloader读取小批量数据样本\n",
    "batch_size=256\n",
    "train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)\n",
    "test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=0)\n",
    "\n",
    "# 3、手动构建模型\n",
    "num_inputs = 784\n",
    "num_outputs = 10  # 共10类\n",
    "\n",
    "# 3.1参数初始化\n",
    "W = torch.normal(0, 0.1, (num_inputs, num_outputs), dtype=torch.float32)  # 784*10\n",
    "b = torch.normal(0, 0.01, (1, num_outputs), dtype=torch.float32)  # 偏差参数1*10\n",
    "# 模型的参数梯度\n",
    "W.requires_grad_(requires_grad=True)\n",
    "b.requires_grad_(requires_grad=True)\n",
    "\n",
    "# 3.2 softmax回归模型\n",
    "def softmax(X):  # softmax计算\n",
    "    X_exp = X.exp()  # 对每个元素做指数运算\n",
    "    partition = X_exp.sum(dim=1, keepdim=True)  # 求列和，即对同行元素求和 n*1\n",
    "    return X_exp / partition  # broadcast\n",
    "def net(X):\n",
    "    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)\n",
    "\n",
    "#3.3 交叉熵损失函数\n",
    "def loss(y_hat, y):\n",
    "    return - torch.log(y_hat.gather(1, y.view(-1, 1)))\n",
    "\n",
    "# 3.4 优化器\n",
    "def sgd(params, lr, batch_size):\n",
    "    for param in params:\n",
    "        param.data -= lr * param.grad / batch_size  # 注意这里更改param时用的param.data\n",
    "\n",
    "# 4、计算分类准确率\n",
    "def evaluate_accurcy(data_iter, net):\n",
    "    right_count, all_num = 0.0, 0\n",
    "    for x, y in data_iter:\n",
    "        right_count += (net(x).argmax(dim=1) == y).float().sum().item()\n",
    "        all_num += y.shape[0]\n",
    "    return right_count / all_num\n",
    "\n",
    "# argmax()和argmin()函数可以寻找向量所在的最小值和最大值的下标，dim选择查找的维度\n",
    "# 5、模型训练\n",
    "lr = 0.1\n",
    "num_epochs = 5\n",
    "for epoch in range(num_epochs):\n",
    "    train_right_sum, train_all_sum, train_loss_sum = 0.0, 0, 0.0\n",
    "    for X, y in tqdm(train_iter):  # tqdm显示训练进度条\n",
    "        y_hat = net(X)\n",
    "        l = loss(y_hat, y).sum()  # 计算loss\n",
    "        l.backward()  # 求梯度\n",
    "        sgd([W, b], lr, batch_size)  # 参数更新\n",
    "        W.grad.data.zero_()\n",
    "        b.grad.data.zero_()  # 梯度清零\n",
    "        train_loss_sum += l.item()  # 损失\n",
    "        train_right_sum += (y_hat.argmax(dim=1) == y).sum().item()  # 训练集准确率\n",
    "        train_all_sum += y.shape[0]\n",
    "    test_acc = evaluate_accurcy(test_iter, net)  # 测试集准确率\n",
    "    print('epoch %d, loss %.4f' % (epoch+1, train_loss_sum/train_all_sum))\n",
    "    print('train accuracy %.3f, test accuracy %.3f' % (train_right_sum/train_all_sum, test_acc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
